# -*- coding: utf-8 -*-
# !/usr/bin/env python
"""
-------------------------------------------------
   File Name：     utils
   Description :   
   Author :       lth
   date：          2022/12/13
-------------------------------------------------
   Change Activity:
                   2022/12/13 17:18: create this script
-------------------------------------------------
"""
__author__ = 'lth'

import torch
from torch import nn


class ACELoss(nn.Module):
    """Aggregation Cross-Entropy style loss.

    Sums per-position class scores over all h*w positions, normalises by the
    number of positions, and takes a cross-entropy against the normalised
    target counts.
    """

    def __init__(self):
        super(ACELoss, self).__init__()

    def forward(self, input, target):
        """Compute the loss, averaged over the batch.

        Args:
            input: tensor of shape (n, h, w, c). Presumably per-position class
                probabilities (e.g. a softmax output) -- TODO confirm with caller.
                Non-contiguous tensors (e.g. after ``permute``) are accepted.
            target: torch tensor whose first column is indexed as
                ``target[:, 0]`` and which broadcasts against an (n, c) tensor.
                Presumably per-class counts -- TODO confirm with caller.
                The caller's tensor is NOT modified.

        Returns:
            Scalar tensor: ``-sum(log(agg_input) * target') / n``.
            Here ``agg_input`` is the input summed over the h*w positions and
            divided by ``T = h * w``. ``target'`` is the target with column 0
            replaced by ``T - target[:, 0]``, then divided by ``T``.
        """
        n, h, w, c = input.size()
        T = h * w
        # reshape instead of view: view() raises on non-contiguous tensors.
        input = input.reshape(n, T, -1)
        # Small epsilon keeps log() finite when a class receives zero mass.
        input = input + 1e-10

        # Clone before the in-place write below. Otherwise the caller's
        # target tensor is corrupted, and repeated calls with the same target
        # would alternate between two different losses.
        target = target.clone()
        # Column 0 becomes T minus its original value; presumably the count of
        # positions not covered by other classes (e.g. blank) -- TODO confirm.
        target[:, 0] = T - target[:, 0]

        # Aggregate over the T positions, then normalise both sides by T.
        input = torch.sum(input, 1)
        input = input / T
        target = target / T

        return -torch.sum(torch.log(input) * target) / n
